{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch-Ignite \n",
    "\n",
    "Set up the runtime of your choice:\n",
    "- CPU\n",
    "- GPU\n",
    "- TPU\n",
    "\n",
    "### TL;DR\n",
    "\n",
    "[PyTorch-Ignite](https://github.com/pytorch/ignite) is a high-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install PyTorch-Ignite (no version is pinned, so pip picks whichever release it resolves)\n",
    "!pip install pytorch-ignite"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Setup helper methods\n",
    "\n",
    "- dataflow\n",
    "- model, optimizer, criterion, learning rate scheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# The presence of COLAB_TPU_ADDR in the environment is treated as \"running on a Colab TPU runtime\".\n",
    "in_colab = \"COLAB_TPU_ADDR\" in os.environ\n",
    "# The presence of WORLD_SIZE is treated as \"started by a distributed launcher\"\n",
    "# (see the torch.distributed.launch command shown in the run cell further down).\n",
    "with_torch_launch = \"WORLD_SIZE\" in os.environ\n",
    "\n",
    "if in_colab:\n",
    "    # Ask the GitHub API for the latest pytorch/xla release and keep its tag name without the leading \"v\".\n",
    "    VERSION = !curl -s https://api.github.com/repos/pytorch/xla/releases/latest | grep -Po '\"tag_name\": \"v\\K.*?(?=\")'\n",
    "    # Install the torch_xla wheel for that version; the wheel file name is fixed to the cp37 (CPython 3.7) build.\n",
    "    !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-{VERSION[0]}-cp37-cp37m-linux_x86_64.whl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "from torchvision import datasets, models\n",
    "from torchvision.transforms import Compose, Normalize, Pad, RandomCrop, RandomHorizontalFlip, ToTensor\n",
    "\n",
    "import ignite.distributed as idist\n",
    "from ignite.contrib.engines import common\n",
    "from ignite.contrib.handlers import ProgressBar\n",
    "from ignite.engine import Engine, Events, create_supervised_evaluator\n",
    "from ignite.metrics import Accuracy\n",
    "\n",
    "\n",
    "# Training-time pipeline: pad by 4 pixels, take a random crop of size 32, randomly flip horizontally,\n",
    "# then convert to a tensor and normalize each of the three channels.\n",
    "# NOTE(review): the second std value is 0.23, while these constants are commonly quoted as\n",
    "# (0.229, 0.224, 0.225) - confirm that 0.23 is intended.\n",
    "train_transform = Compose(\n",
    "    [\n",
    "        Pad(4),\n",
    "        RandomCrop(32, fill=128),\n",
    "        RandomHorizontalFlip(),\n",
    "        ToTensor(),\n",
    "        Normalize((0.485, 0.456, 0.406), (0.229, 0.23, 0.225)),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Evaluation-time pipeline: no augmentation, only tensor conversion and the same normalization as training.\n",
    "test_transform = Compose([ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.23, 0.225)),])\n",
    "\n",
    "\n",
    "def get_train_test_datasets(path):\n",
    "    \"\"\"Build the CIFAR10 training and test datasets stored under ``path``.\n",
    "\n",
    "    The training split is created with ``download=True``; the test split is created with\n",
    "    ``download=False``, so it expects the files to be present under ``path`` already.\n",
    "\n",
    "    Returns a ``(train, test)`` tuple of datasets using ``train_transform`` and ``test_transform``.\n",
    "    \"\"\"\n",
    "    train_split = datasets.CIFAR10(root=path, train=True, transform=train_transform, download=True)\n",
    "    test_split = datasets.CIFAR10(root=path, train=False, transform=test_transform, download=False)\n",
    "    return train_split, test_split\n",
    "\n",
    "\n",
    "def get_model(name):\n",
    "    \"\"\"Instantiate the torchvision model registered as ``name`` with ``num_classes=10``.\n",
    "\n",
    "    Raises ``RuntimeError`` when ``name`` is not found in ``models.__dict__``.\n",
    "    \"\"\"\n",
    "    # Guard clause: reject unknown names before looking up the factory\n",
    "    if name not in models.__dict__:\n",
    "        raise RuntimeError(f\"Unknown model name {name}\")\n",
    "\n",
    "    model_factory = models.__dict__[name]\n",
    "    return model_factory(num_classes=10)\n",
    "\n",
    "\n",
    "def get_dataflow(config):\n",
    "    \"\"\"Create the training and test data loaders for the current distributed setup.\n",
    "\n",
    "    Reads ``data_path`` (default ``\".\"``), ``batch_size`` (default 512) and ``num_workers``\n",
    "    (default 8) from ``config``, and stores the number of training batches per epoch\n",
    "    in ``config`` under the ``num_iters_per_epoch`` key.\n",
    "\n",
    "    Returns a ``(train_loader, test_loader)`` tuple; the test loader uses twice the batch size.\n",
    "    \"\"\"\n",
    "    rank = idist.get_rank()\n",
    "\n",
    "    # Every rank except 0 waits at the barrier while rank 0 builds (and downloads) the datasets\n",
    "    if rank > 0:\n",
    "        idist.barrier()\n",
    "\n",
    "    data_root = config.get(\"data_path\", \".\")\n",
    "    train_dataset, test_dataset = get_train_test_datasets(data_root)\n",
    "\n",
    "    # Rank 0 joins the barrier only after its datasets exist, which lets the waiting ranks proceed\n",
    "    if rank == 0:\n",
    "        idist.barrier()\n",
    "\n",
    "    batch_size = config.get(\"batch_size\", 512)\n",
    "    num_workers = config.get(\"num_workers\", 8)\n",
    "\n",
    "    # Data loaders are also adapted to the distributed config: nccl, gloo, xla-tpu\n",
    "    train_loader = idist.auto_dataloader(\n",
    "        train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, drop_last=True,\n",
    "    )\n",
    "    # Recorded for later use by the learning rate scheduler setup\n",
    "    config[\"num_iters_per_epoch\"] = len(train_loader)\n",
    "\n",
    "    test_loader = idist.auto_dataloader(\n",
    "        test_dataset, batch_size=2 * batch_size, num_workers=num_workers, shuffle=False,\n",
    "    )\n",
    "    return train_loader, test_loader\n",
    "\n",
    "\n",
    "def initialize(config):\n",
    "    \"\"\"Create the model, optimizer, loss and learning rate scheduler described by ``config``.\n",
    "\n",
    "    Required keys: ``model`` and ``num_iters_per_epoch``. Optional keys with their defaults:\n",
    "    ``learning_rate`` (0.1), ``momentum`` (0.9), ``weight_decay`` (1e-5).\n",
    "\n",
    "    Returns a ``(model, optimizer, criterion, lr_scheduler)`` tuple.\n",
    "    \"\"\"\n",
    "    # Build the network and adapt it for distributed settings if configured\n",
    "    model = idist.auto_model(get_model(config[\"model\"]))\n",
    "\n",
    "    sgd = optim.SGD(\n",
    "        model.parameters(),\n",
    "        lr=config.get(\"learning_rate\", 0.1),\n",
    "        momentum=config.get(\"momentum\", 0.9),\n",
    "        weight_decay=config.get(\"weight_decay\", 1e-5),\n",
    "        nesterov=True,\n",
    "    )\n",
    "    optimizer = idist.auto_optim(sgd)\n",
    "\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    criterion = loss_fn.to(idist.device())\n",
    "\n",
    "    # step_size is the number of iterations in one epoch; the learning rate is multiplied by 0.9\n",
    "    # every step_size scheduler steps\n",
    "    iters_per_epoch = config[\"num_iters_per_epoch\"]\n",
    "    lr_scheduler = StepLR(optimizer, step_size=iters_per_epoch, gamma=0.9)\n",
    "\n",
    "    return model, optimizer, criterion, lr_scheduler\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Setup your trainer\n",
    "\n",
    "The trainer is defined as an Ignite `Engine` driven by the user's `train_step` logic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_trainer(model, optimizer, criterion, lr_scheduler, config):\n",
    "    \"\"\"Create the Ignite ``Engine`` that performs one optimization step per batch.\n",
    "\n",
    "    Args:\n",
    "        model: network to train; switched to train mode on every step.\n",
    "        optimizer: optimizer stepped once per batch.\n",
    "        criterion: loss callable invoked as ``criterion(y_pred, y)``.\n",
    "        lr_scheduler: scheduler stepped once per batch, right after the optimizer.\n",
    "        config: dict; ``output_path`` (default ``\"output\"``) is the checkpoint directory.\n",
    "\n",
    "    Returns:\n",
    "        The trainer engine. On rank 0 it additionally saves the model weights every\n",
    "        200 iterations and displays a progress bar with the batch loss.\n",
    "    \"\"\"\n",
    "\n",
    "    # Define any training logic for iteration update\n",
    "    def train_step(engine, batch):\n",
    "        device = idist.device()\n",
    "        x, y = batch[0].to(device), batch[1].to(device)\n",
    "\n",
    "        model.train()\n",
    "        y_pred = model(x)\n",
    "        loss = criterion(y_pred, y)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "\n",
    "        # The returned float becomes the engine output shown by the progress bar\n",
    "        return loss.item()\n",
    "\n",
    "    # Define trainer engine\n",
    "    trainer = Engine(train_step)\n",
    "\n",
    "    if idist.get_rank() == 0:\n",
    "        # Add any custom handlers\n",
    "        @trainer.on(Events.ITERATION_COMPLETED(every=200))\n",
    "        def save_checkpoint():\n",
    "            output_dir = Path(config.get(\"output_path\", \"output\"))\n",
    "            # torch.save does not create missing directories, so make sure the target exists\n",
    "            # here instead of relying on another component having created it beforehand\n",
    "            output_dir.mkdir(parents=True, exist_ok=True)\n",
    "            torch.save(model.state_dict(), output_dir / \"checkpoint.pt\")\n",
    "\n",
    "        # Add progress bar showing batch loss value\n",
    "        ProgressBar().attach(trainer, output_transform=lambda x: {\"batch loss\": x})\n",
    "\n",
    "    return trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Setup training and validation\n",
    "\n",
    "Assemble all parts together in the training method, configured with a `config` dictionary:\n",
    "- instantiate dataflow, model, optimizer, criterion, lr scheduler\n",
    "- instantiate trainer, validation engine (`evaluator`)\n",
    "- setup validation phase and print results \n",
    "- setup tensorboard logger\n",
    "- start the training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training(local_rank, config):\n",
    "    \"\"\"Run the complete train / validate loop inside one process.\n",
    "\n",
    "    Args:\n",
    "        local_rank: first positional argument of the launched function; the notebook starts\n",
    "            this function with ``parallel.run(training, config)``. It is not used in the body.\n",
    "        config: dict of settings; ``output_path`` (default ``\"output\"``) and ``max_epochs``\n",
    "            (default 3) are read here, and the dict is passed on to the helper functions.\n",
    "    \"\"\"\n",
    "\n",
    "    # Setup dataflow, model, optimizer, criterion and learning rate scheduler\n",
    "    train_loader, val_loader = get_dataflow(config)\n",
    "    model, optimizer, criterion, lr_scheduler = initialize(config)\n",
    "\n",
    "    # Setup model trainer and evaluator\n",
    "    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, config)\n",
    "    evaluator = create_supervised_evaluator(model, metrics={\"accuracy\": Accuracy()}, device=idist.device())\n",
    "\n",
    "    # Run model evaluation every 3 epochs and show results\n",
    "    @trainer.on(Events.EPOCH_COMPLETED(every=3))\n",
    "    def evaluate_model():\n",
    "        state = evaluator.run(val_loader)\n",
    "        if idist.get_rank() == 0:\n",
    "            print(state.metrics)\n",
    "\n",
    "    # Setup tensorboard experiment tracking (rank 0 only; stays None on every other rank)\n",
    "    tb_logger = None\n",
    "    if idist.get_rank() == 0:\n",
    "        tb_logger = common.setup_tb_logging(\n",
    "            config.get(\"output_path\", \"output\"), trainer, optimizer, evaluators={\"validation\": evaluator},\n",
    "        )\n",
    "\n",
    "    try:\n",
    "        trainer.run(train_loader, max_epochs=config.get(\"max_epochs\", 3))\n",
    "    finally:\n",
    "        # Close the logger even when training raises, so it is not left open on failure\n",
    "        if tb_logger is not None:\n",
    "            tb_logger.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run on your infrastructure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the TensorBoard notebook extension and open it on the \"output\" directory,\n",
    "# which is the default `output_path` used by the training code above.\n",
    "%load_ext tensorboard\n",
    "%tensorboard --logdir=output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-07-24 08:44:07,245 ignite.distributed.launcher.Parallel INFO: - Run '<function training at 0x7f6268c32950>' in 1 processes\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b71e1d83d024410ea3692ffbf444abfb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce0c2f1d11af49ae85d5c2f657533ef9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./cifar-10-python.tar.gz to .\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-07-24 08:44:17,932 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Dataset CIFAR10': \n",
      "\t{'batch_size': 512, 'num_workers': 8, 'shuffle': True, 'drop_last': True, 'pin_memory': True}\n",
      "2020-07-24 08:44:17,933 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Dataset CIFAR10': \n",
      "\t{'batch_size': 1024, 'num_workers': 8, 'shuffle': False, 'pin_memory': True}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=97.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-07-24 08:44:45,937 ignite.distributed.launcher.Parallel INFO: End of run\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'accuracy': 0.513}\n"
     ]
    }
   ],
   "source": [
    "# Each section below builds the same config and differs only in `backend` and `nproc_per_node`.\n",
    "# NOTE(review): no code visible in this notebook reads the \"dataset\" config key - confirm it is needed.\n",
    "#\n",
    "# --- Single computation device ---\n",
    "# $ python main.py\n",
    "#\n",
    "# Selected when neither the Colab TPU variable nor WORLD_SIZE was found in the environment.\n",
    "if __name__ == \"__main__\" and not (in_colab or with_torch_launch):\n",
    "\n",
    "    backend = None  # or \"nccl\", \"gloo\", \"xla-tpu\" ...\n",
    "    nproc_per_node = None  # or N to spawn N processes\n",
    "    config = {\n",
    "        \"model\": \"resnet18\",\n",
    "        \"dataset\": \"cifar10\",\n",
    "    }\n",
    "\n",
    "    with idist.Parallel(backend=backend, nproc_per_node=nproc_per_node) as parallel:\n",
    "        parallel.run(training, config)\n",
    "\n",
    "\n",
    "# --- Multiple GPUs ---\n",
    "# $ python -m torch.distributed.launch --nproc_per_node=2 --use_env main.py\n",
    "#\n",
    "# Selected when WORLD_SIZE was found in the environment.\n",
    "if __name__ == \"__main__\" and with_torch_launch:\n",
    "\n",
    "    backend = \"nccl\"  # or \"gloo\" ...\n",
    "    nproc_per_node = None  # left as None here; the launch command above sets --nproc_per_node\n",
    "    config = {\n",
    "        \"model\": \"resnet18\",\n",
    "        \"dataset\": \"cifar10\",\n",
    "    }\n",
    "\n",
    "    with idist.Parallel(backend=backend, nproc_per_node=nproc_per_node) as parallel:\n",
    "        parallel.run(training, config)\n",
    "\n",
    "# --- Multiple TPUs ---\n",
    "# In Colab\n",
    "#\n",
    "# NOTE(review): unlike the two sections above, this one is not guarded by\n",
    "# `__name__ == \"__main__\"` - confirm that this is intentional.\n",
    "if in_colab:\n",
    "\n",
    "    backend = \"xla-tpu\"  # backend used for the Colab TPU runtime\n",
    "    nproc_per_node = 8  # or N to spawn N processes\n",
    "    config = {\n",
    "        \"model\": \"resnet18\",\n",
    "        \"dataset\": \"cifar10\",\n",
    "    }\n",
    "\n",
    "    with idist.Parallel(backend=backend, nproc_per_node=nproc_per_node) as parallel:\n",
    "        parallel.run(training, config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Other links\n",
    "\n",
    "- Full featured CIFAR10 example: https://github.com/pytorch/ignite/tree/master/examples/contrib/cifar10\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
